In [1]:
from os.path import basename

from toolz import merge
%cd /home/sitnarf/projects/cardiovascular-risk-app/backend

from sklearn.calibration import calibration_curve
from functional import pipe, or_fn

from functools import partial
from typing import Callable, Iterable, Mapping
from nested_cv import get_cv_results_from_simple_cv_evaluation

from notebooks.heart_transplant.dependencies.heart_transplant_functions import format_feature
from scripts.feature_importance import plot_feature_importance_formatted


from matplotlib import pyplot
from pandas import Series

from visualisation import plot_roc_from_result, display_html, savefig, display_dict_as_table_horizontal

%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'

import itertools
import os
from pathlib import Path
# Walk from the current working directory up through its parents and switch into the
# first directory that contains a Pipfile, so that relative paths used later (e.g.
# PUBLISH_FOLDER) resolve against it. If no Pipfile is found, the cwd is left unchanged.
# NOTE(review): the project-local imports earlier in this cell already ran under the
# hard-coded `%cd` above, so this search only affects what happens after this point.
for folder in itertools.chain([Path.cwd()], Path.cwd().parents):
    if (folder / 'Pipfile').exists():
        os.chdir(folder)
        break


from utils import evaluate_and_assign_if_not_present
from notebooks.heart_transplant.heart_transplant_evaluate import HEART_TRANSPLANT_CV_SHUFFLED_IDENTIFIER, \
    HEART_TRANSPLANT_EXPANDING_IDENTIFIER

import shelve
from notebooks.heart_transplant.dependencies.heart_transplant_metadata import heart_transplant_metadata as metadata

from evaluation_functions import compute_classification_metrics_from_results_with_statistics, join_folds_cv_result, \
    ModelResult, get_1_class_y_score
from formatting import compare_metrics_in_table,   b, render_struct_table
from notebooks.heart_transplant.dependencies.heart_transplant_data import get_reduced_binary_dataset_cached

# Target directory for every exported figure; relative, so it resolves against the
# current working directory at the time each figure is saved.
PUBLISH_FOLDER = './data/heart_transplant/publish'

# Restore matplotlib's stock rcParams so exported figures do not inherit styling set
# by earlier cells or imported modules.
pyplot.style.use('default')
/home/sitnarf/projects/cardiovascular-risk-app/backend
In [2]:
# Load the reduced binary dataset twice: once with the loader's default survival
# horizon and once with survival_days=90. The `_365` suffix suggests the default
# horizon is 365 days — NOTE(review): confirm against get_reduced_binary_dataset_cached.
# The loader presumably returns (features, labels, raw dataset); the cells below only
# read the label series y_365 / y_90.
X_365, y_365, dataset_raw = get_reduced_binary_dataset_cached()

X_90, y_90, _ = get_reduced_binary_dataset_cached(survival_days=90)
In [3]:
def optimized_filter(key: str) -> bool:
    """Select result keys that end in ``optimized_roc``.

    Only the suffix is matched, so variants such as
    ``xgboost_optimized_roc_expert_features`` are rejected.
    """
    suffix = 'optimized_roc'
    return key.endswith(suffix)

def default_filter(key: str) -> bool:
    """Select result keys that end in ``default`` (presumably the untuned runs)."""
    wanted_suffix = 'default'
    return key.endswith(wanted_suffix)

def expert_filter(key: str) -> bool:
    """Select result keys containing ``expert`` anywhere (e.g. ``*_expert_features`` runs)."""
    return key.find('expert') != -1

def color_methods(key: str) -> Mapping:
    """Return plot kwargs colouring a result key by its model family.

    Colours come from matplotlib's qualitative ``Paired`` colormap, whose colours are
    arranged as light/dark pairs. Each recognised family gets the dark member of its
    own pair. Keys containing ``'default'`` get the light partner, so a family's
    tuned and default variants look related.

    :param key: result key, e.g. ``'xgboost_optimized_roc'`` or ``'ridge_default'``.
    :return: ``{'color': <RGB tuple>}``, suitable as matplotlib plot kwargs.
    """
    # Checked in order; the first family contained in ``key`` wins (same precedence
    # as the former if/elif chain).
    family_dark_index = (
        ('xgboost', 1),
        ('random_forest', 3),
        ('ridge', 5),
    )
    # Unrecognised families used to fall back to index 1, the same colour as xgboost,
    # which made them indistinguishable in the plots. Give them their own pair instead.
    fallback_dark_index = 7

    index = next(
        (dark_index for family, dark_index in family_dark_index if family in key),
        fallback_dark_index,
    )

    if 'default' in key:
        # Step back to the light member of the same pair.
        index -= 1

    return {'color': pyplot.get_cmap('Paired').colors[index]}

def present_results(
        file_name: str,
        y: Series,
        include_ci: bool = False,
        filter_callback: Callable[[str], bool] = None,
        threshold: float = 0.5,
) -> None:
    """Render an HTML table comparing classification metrics of stored results.

    Opens the shelve database ``file_name`` read-only. For every entry whose key passes
    ``filter_callback``, it computes metrics of ``item['chosen']['result']`` against
    ``y``. The table shows ROC AUC, F1, recall and FPR.

    :param file_name: path/identifier of the shelve database holding the results.
    :param y: ground-truth labels the stored predictions are scored against.
    :param include_ci: forwarded to ``compare_metrics_in_table``; presumably adds
        confidence intervals to the table — confirm.
    :param filter_callback: predicate on the result key; ``None`` keeps every entry.
    :param threshold: decision threshold forwarded to the metric computation
        (defaults to the previously hard-coded 0.5).
    """
    metrics = {}

    # The context manager closes the shelf even if a metric computation raises.
    with shelve.open(file_name, flag='r') as results:
        for name, item in results.items():
            if filter_callback is not None and not filter_callback(name):
                continue
            evaluate_and_assign_if_not_present(
                metrics,
                name,
                # Bind ``item`` as a default argument. A plain closure would see only
                # the loop's latest value if the callable were invoked after the loop
                # advanced.
                lambda item=item: compute_classification_metrics_from_results_with_statistics(
                    y,
                    [item['chosen']['result']],
                    threshold=threshold,
                    ignore_warning=True,
                ),
            )

    pipe(
        compare_metrics_in_table(
            metrics,
            include=('roc_auc', 'f1', 'recall', 'fpr'),
            include_ci=include_ci,
        ),
        render_struct_table,
        display_html,
    )


def present_rocs(
        file_name: str,
        y: Series,
        filter_callback: Callable[[str], bool] = None,
        style_by_callback: Iterable[Callable[[str], Mapping]] = None,
) -> None:
    """Plot the ROC curves of selected stored results on one figure, then export and show it.

    For every entry of the shelve database ``file_name`` whose key passes
    ``filter_callback``, ``plot_roc_from_result`` draws the ROC curve of
    ``item['chosen']['result']`` against ``y``. The figure is saved as
    ``PUBLISH_FOLDER/rocs_<basename of file_name>``.

    :param file_name: path/identifier of the shelve database holding the results.
    :param y: ground-truth labels.
    :param filter_callback: predicate on the result key; ``None`` keeps every entry.
    :param style_by_callback: callables mapping a result key to matplotlib kwargs.
        Their outputs are merged left to right, so later callbacks override earlier
        ones. ``None`` means no extra styling.
    """
    # Materialise once. The callbacks are applied per method, so a one-shot iterator
    # (e.g. a generator) would otherwise be exhausted after the first method.
    style_callbacks = tuple(style_by_callback) if style_by_callback is not None else ()

    # The context manager closes the shelf even if plotting raises.
    with shelve.open(file_name, flag='r') as results:
        for method_name, item in results.items():
            if filter_callback is not None and not filter_callback(method_name):
                continue
            style = merge(*[callback(method_name) for callback in style_callbacks], {})
            plot_roc_from_result(y, item['chosen']['result'], label=method_name, plot_kwargs=style)

    savefig(PUBLISH_FOLDER + f'/rocs_{basename(file_name)}')

    pyplot.show()

Shuffled CV

In [4]:
# Metrics for every stored shuffled-CV method, then ROC curves of the tuned and default runs.
present_results(file_name=HEART_TRANSPLANT_CV_SHUFFLED_IDENTIFIER, y=y_365)
print()
present_rocs(
    file_name=HEART_TRANSPLANT_CV_SHUFFLED_IDENTIFIER,
    y=y_365,
    filter_callback=or_fn(optimized_filter, default_filter),
    style_by_callback=[color_methods],
)
ROC/AUC Δ f1 Δ TPR Δ FPR Δ
random_forest_default 0.883 0.0 0.612 -0.015 0.442 -0.145 0.0 -0.345
xgboost_optimized_roc 0.86 -0.024 0.627 0.0 0.476 -0.111 0.006 -0.339
xgboost_optimized_roc_expert_features 0.856 -0.028 0.626 0.0 0.469 -0.118 0.004 -0.341
random_forest_optimized_roc 0.841 -0.042 0.582 -0.045 0.519 -0.068 0.038 -0.308
xgboost_default 0.772 -0.112 0.43 -0.197 0.574 -0.013 0.155 -0.19
ridge_optimized_roc 0.669 -0.214 0.291 -0.335 0.586 0.0 0.345 0.0
ridge_default 0.669 -0.214 0.292 -0.335 0.587 0.0 0.345 0.0
ridge_optimized_roc_expert_features 0.662 -0.221 0.289 -0.338 0.575 -0.012 0.341 -0.004

Expanding window

365 days

All age groups

In [5]:
# Expanding window, 365 days, all ages. The table also lists expert-feature runs;
# the ROC plot shows only the tuned and default runs.
present_results(
    file_name=HEART_TRANSPLANT_EXPANDING_IDENTIFIER + '_365_all',
    y=y_365,
    filter_callback=or_fn(optimized_filter, default_filter, expert_filter),
)
print()
present_rocs(
    file_name=HEART_TRANSPLANT_EXPANDING_IDENTIFIER + '_365_all',
    y=y_365,
    filter_callback=or_fn(optimized_filter, default_filter),
    style_by_callback=[color_methods],
)
ROC/AUC Δ f1 Δ TPR Δ FPR Δ
xgboost_optimized_roc 0.672 0.0 0.275 0.0 0.533 -0.042 0.297 -0.041
xgboost_optimized_roc_expert_features 0.67 -0.002 0.271 -0.004 0.575 0.0 0.338 0.0
random_forest_optimized_roc 0.669 -0.003 0.219 -0.056 0.178 -0.398 0.054 -0.284
random_forest_optimized_roc_expert_features 0.668 -0.004 0.232 -0.043 0.202 -0.373 0.066 -0.272
ridge_optimized_roc 0.657 -0.015 0.274 -0.001 0.471 -0.105 0.247 -0.091
ridge_default 0.657 -0.015 0.272 -0.003 0.477 -0.098 0.256 -0.082
ridge_optimized_roc_expert_features 0.654 -0.018 0.27 -0.005 0.485 -0.09 0.264 -0.074
random_forest_default 0.636 -0.036 0.003 -0.272 0.003 -0.573 0.0 -0.338
xgboost_default 0.595 -0.077 0.205 -0.07 0.218 -0.358 0.112 -0.226

< 18 years old

In [6]:
# Expanding window, 365 days, recipients younger than 18.
present_results(
    file_name=HEART_TRANSPLANT_EXPANDING_IDENTIFIER + '_365_l_18',
    y=y_365,
    filter_callback=or_fn(optimized_filter, default_filter),
)
print()
present_rocs(
    file_name=HEART_TRANSPLANT_EXPANDING_IDENTIFIER + '_365_l_18',
    y=y_365,
    filter_callback=or_fn(optimized_filter, default_filter),
    style_by_callback=[color_methods],
)
ROC/AUC Δ f1 Δ TPR Δ FPR Δ
random_forest_optimized_roc 0.751 0.0 0.171 -0.139 0.111 -0.419 0.016 -0.179
xgboost_optimized_roc 0.742 -0.009 0.261 -0.049 0.225 -0.305 0.048 -0.147
ridge_optimized_roc 0.735 -0.015 0.31 0.0 0.53 0.0 0.195 0.0
random_forest_default 0.732 -0.019 0.034 -0.276 0.03 -0.5 0.001 -0.194
ridge_default 0.72 -0.031 0.294 -0.016 0.481 -0.049 0.187 -0.008
xgboost_default 0.676 -0.075 0.19 -0.12 0.147 -0.383 0.033 -0.162

>= 18 years old

In [7]:
# Expanding window, 365 days, recipients aged 18 or older.
present_results(
    HEART_TRANSPLANT_EXPANDING_IDENTIFIER + '_365_me_18',
    y_365,
    filter_callback=or_fn(optimized_filter, default_filter),
)
# Blank line between table and figure, matching the other results/ROC cells.
print()
present_rocs(
    HEART_TRANSPLANT_EXPANDING_IDENTIFIER + '_365_me_18',
    y_365,
    filter_callback=or_fn(optimized_filter, default_filter),
    style_by_callback=[color_methods],
)
ROC/AUC Δ f1 Δ TPR Δ FPR Δ
xgboost_optimized_roc 0.664 0.0 0.272 0.0 0.556 0.0 0.326 0.0
random_forest_optimized_roc 0.655 -0.008 0.225 -0.046 0.205 -0.351 0.075 -0.251
ridge_optimized_roc 0.647 -0.016 0.046 -0.226 0.025 -0.531 0.002 -0.324
ridge_default 0.643 -0.021 0.263 -0.009 0.482 -0.074 0.278 -0.048
random_forest_default 0.623 -0.041 0.001 -0.27 0.003 -0.553 0.0 -0.326
xgboost_default 0.583 -0.08 0.183 -0.089 0.189 -0.367 0.107 -0.219

90 days

In [8]:
# Expanding window, 90-day survival horizon, all ages.
present_results(
    file_name=HEART_TRANSPLANT_EXPANDING_IDENTIFIER + '_90_all',
    y=y_90,
    filter_callback=or_fn(optimized_filter, default_filter),
)
print()
present_rocs(
    file_name=HEART_TRANSPLANT_EXPANDING_IDENTIFIER + '_90_all',
    y=y_90,
    filter_callback=or_fn(optimized_filter, default_filter),
    style_by_callback=[color_methods],
)
ROC/AUC Δ f1 Δ TPR Δ FPR Δ
random_forest_optimized_roc 0.701 0.0 0.175 -0.041 0.132 -0.365 0.026 -0.206
xgboost_optimized_roc 0.696 -0.005 0.15 -0.066 0.105 -0.392 0.018 -0.214
ridge_optimized_roc 0.688 -0.014 0.216 -0.001 0.495 -0.002 0.232 0.0
ridge_default 0.687 -0.014 0.216 0.0 0.497 0.0 0.232 0.0
random_forest_default 0.659 -0.042 0.0 -0.216 0.003 -0.494 0.0 -0.232
xgboost_default 0.597 -0.104 0.138 -0.078 0.141 -0.356 0.064 -0.168

Calibration plots

Shuffled CV

In [9]:
def present_calibration_plots(
        file_name: str,
        y: Series,
        filter_callback: Callable[[str], bool] = None,
        style_by_callback: Iterable[Callable[[str], Mapping]] = None,
        n_bins: int = 20,
) -> None:
    """Plot calibration curves and predicted-score histograms for stored results.

    The top panel (two thirds of the height) shows each method's calibration curve,
    computed with ``sklearn.calibration.calibration_curve``, against the diagonal of
    perfect calibration. The bottom panel shows a step histogram of the same
    positive-class scores. Cross-validation folds are joined with
    ``join_folds_cv_result`` first. The figure is saved to
    ``PUBLISH_FOLDER/calibration/<basename of file_name>`` and shown.

    :param file_name: path/identifier of the shelve database holding the results.
    :param y: ground-truth labels; indexed by the joined test-score index.
    :param filter_callback: predicate on the result key; ``None`` keeps every entry.
    :param style_by_callback: callables mapping a result key to matplotlib kwargs.
        Their outputs are merged left to right (later ones win). ``None`` means no
        extra styling; this previously raised ``TypeError``.
    :param n_bins: number of bins for both the calibration curve and the histogram.
    """
    # Materialise once. The callbacks run for every method, so a one-shot iterator
    # would otherwise be exhausted after the first one.
    style_callbacks = tuple(style_by_callback) if style_by_callback is not None else ()

    # No fixed figure number: ``figure(1)`` would draw into an already-open figure 1.
    pyplot.figure(figsize=(10, 10))
    calibration_ax = pyplot.subplot2grid((3, 1), (0, 0), rowspan=2)
    histogram_ax = pyplot.subplot2grid((3, 1), (2, 0))
    # Reference diagonal: a perfectly calibrated model lies on y = x.
    calibration_ax.plot([0, 1], [0, 1], color='gray', lw=1, linestyle='--')

    with shelve.open(file_name, flag='r') as results:
        # Sorted by key so both panels are drawn in the same, stable order.
        for method, item in sorted(results.items(), key=lambda entry: entry[0]):
            if filter_callback is not None and not filter_callback(method):
                continue
            style = merge(*[callback(method) for callback in style_callbacks], {})
            # Join the folds once and reuse the scores for both panels.
            result_joined: ModelResult = join_folds_cv_result(item['chosen']['result'])
            y_score = get_1_class_y_score(result_joined['y_test_score'])
            fraction_of_positives, mean_predicted_value = calibration_curve(
                y.loc[result_joined['y_test_score'].index],
                y_score,
                n_bins=n_bins,
            )
            calibration_ax.plot(mean_predicted_value, fraction_of_positives, "s-", label=method, **style)
            histogram_ax.hist(
                y_score,
                bins=n_bins,
                label=method,
                histtype="step",
                lw=2,
                **style
            )

    calibration_ax.legend(loc="lower right")
    calibration_ax.set_xlabel('Mean predicted value')
    calibration_ax.set_ylabel('Fraction of positives')
    histogram_ax.set_xlabel('Predicted score')
    histogram_ax.set_ylabel('Count')

    savefig(PUBLISH_FOLDER + f'/calibration/{basename(file_name)}')
    # pyplot.show() as in present_rocs; Figure.show() targets GUI backends and may
    # warn under the notebook's inline backend.
    pyplot.show()

present_calibration_plots(
    file_name=HEART_TRANSPLANT_CV_SHUFFLED_IDENTIFIER,
    y=y_365,
    filter_callback=or_fn(optimized_filter, default_filter),
    style_by_callback=[color_methods],
)

Expanding window

365 days

In [10]:
present_calibration_plots(
    file_name=HEART_TRANSPLANT_EXPANDING_IDENTIFIER + '_365_all',
    y=y_365,
    filter_callback=or_fn(optimized_filter, default_filter),
    style_by_callback=[color_methods],
)

All age groups

< 18 years old

In [11]:
present_calibration_plots(
    file_name=HEART_TRANSPLANT_EXPANDING_IDENTIFIER + '_365_l_18',
    y=y_365,
    filter_callback=or_fn(optimized_filter, default_filter),
    style_by_callback=[color_methods],
)

>= 18 years old

In [12]:
present_calibration_plots(
    file_name=HEART_TRANSPLANT_EXPANDING_IDENTIFIER + '_365_me_18',
    y=y_365,
    filter_callback=or_fn(optimized_filter, default_filter),
    style_by_callback=[color_methods],
)
In [101]:
 

90 days

In [13]:
present_calibration_plots(
    file_name=HEART_TRANSPLANT_EXPANDING_IDENTIFIER + '_90_all',
    y=y_90,
    filter_callback=or_fn(optimized_filter, default_filter),
    style_by_callback=[color_methods],
)

Feature importance

In [14]:
def present_feature_importance(
        file_name: str,
        filter_callback: Callable[[str], bool] = None,
        n_features: int = 20,
) -> None:
    """Plot, export and show the top feature importances of each selected stored model.

    For every entry of the shelve database ``file_name`` whose key passes
    ``filter_callback``, draws ``plot_feature_importance_formatted`` on its own figure,
    with feature names formatted via ``format_feature`` and the notebook's metadata.

    :param file_name: path/identifier of the shelve database holding the results.
    :param filter_callback: predicate on the result key; ``None`` keeps every entry.
    :param n_features: number of top features to plot (defaults to the previously
        hard-coded 20).
    """
    # The context manager guarantees the shelf is closed (it previously never was).
    with shelve.open(file_name, flag='r') as results:
        for name, item in results.items():
            if filter_callback is not None and not filter_callback(name):
                continue
            pyplot.figure(figsize=(10, 10))
            pyplot.title(name)
            plot_feature_importance_formatted(
                get_cv_results_from_simple_cv_evaluation([item]),
                format_feature=partial(format_feature, metadata),
                n_features=n_features,
            )
            pyplot.tight_layout()
            # Prefix with the shelf's basename, as present_rocs and the calibration
            # plots do. The same method keys occur in every shelf, so exporting by key
            # alone made each call overwrite the figures of the previous one.
            savefig(PUBLISH_FOLDER + f'/feature_importance/{basename(file_name)}_{name}')
            pyplot.show()

Shuffled CV

In [15]:
present_feature_importance(
    file_name=HEART_TRANSPLANT_CV_SHUFFLED_IDENTIFIER,
    filter_callback=optimized_filter,
)

Expanding window

365 days

All age groups

In [16]:
present_feature_importance(
    file_name=HEART_TRANSPLANT_EXPANDING_IDENTIFIER + '_365_all',
    filter_callback=optimized_filter,
)

< 18 years old

In [17]:
present_feature_importance(
    file_name=HEART_TRANSPLANT_EXPANDING_IDENTIFIER + '_365_l_18',
    filter_callback=optimized_filter,
)

>= 18 years old

In [18]:
present_feature_importance(
    file_name=HEART_TRANSPLANT_EXPANDING_IDENTIFIER + '_365_me_18',
    filter_callback=optimized_filter,
)

90 days

In [19]:
present_feature_importance(
    file_name=HEART_TRANSPLANT_EXPANDING_IDENTIFIER + '_90_all',
    filter_callback=optimized_filter,
)

Chosen hyperparameters

In [20]:
def present_chosen_hyperparameters(file_name: str, filter_callback: Callable = None, pretty: bool = True) -> None:
    """Display the classifier hyperparameters chosen for each selected stored method.

    For every entry of the shelve database ``file_name`` whose key passes
    ``filter_callback``:

    - the key is shown via ``b``;
    - ``item['chosen']['configuration']['classifier']`` follows, as a horizontal
      HTML table when ``pretty`` is true and printed otherwise;
    - entries missing that key path show "No configuration";
    - a blank line separates consecutive entries.

    :param file_name: path/identifier of the shelve database holding the results.
    :param filter_callback: predicate on the result key; ``None`` keeps every entry.
    :param pretty: render as an HTML table instead of printing the raw mapping.
    """
    # The context manager guarantees the shelf is closed (it previously never was).
    with shelve.open(file_name, flag='r') as results:
        for name, item in results.items():
            # Truthiness test, like the other present_* helpers. The former
            # ``is True`` silently skipped predicates returning truthy non-bool values.
            if filter_callback is not None and not filter_callback(name):
                continue
            b(name)
            try:
                hyperparameters = item['chosen']['configuration']['classifier']
            except KeyError:
                display_html('<p>No configuration</p>')
            else:
                if pretty:
                    display_dict_as_table_horizontal(hyperparameters)
                else:
                    print(hyperparameters)

            print()

Shuffled CV

In [21]:
present_chosen_hyperparameters(
    file_name=HEART_TRANSPLANT_CV_SHUFFLED_IDENTIFIER,
    filter_callback=optimized_filter,
    pretty=True,
)
ridge_optimized_roc
C
0.1

xgboost_optimized_roc
colsample_bytree gamma learning_rate max_depth min_child_weight n_estimators subsample
0.5203627244512031 1.0332652273520901 0.029104815865875405 16 2.0 195 0.9364481745668485

random_forest_optimized_roc
bootstrap max_depth max_features min_samples_leaf min_samples_split n_estimators
False 16 auto 8 18 430

Expanding window

365 days

All age groups

In [22]:
present_chosen_hyperparameters(
    file_name=HEART_TRANSPLANT_EXPANDING_IDENTIFIER + '_365_all',
    filter_callback=optimized_filter,
    pretty=True,
)
ridge_optimized_roc
C
0.001

xgboost_optimized_roc
colsample_bytree gamma learning_rate max_depth min_child_weight n_estimators subsample
0.5338977067906682 1.4443844778754271 0.0453271241042913 3 8.0 190 0.38420232006551414

random_forest_optimized_roc
bootstrap max_depth max_features min_samples_leaf min_samples_split n_estimators
True 15 auto 20 10 447

< 18 years old

In [23]:
present_chosen_hyperparameters(
    file_name=HEART_TRANSPLANT_EXPANDING_IDENTIFIER + '_365_l_18',
    filter_callback=optimized_filter,
    pretty=True,
)
ridge_optimized_roc
C
0.001

xgboost_optimized_roc
colsample_bytree gamma learning_rate max_depth min_child_weight n_estimators subsample
0.3554189414042873 4.126603155997045 0.012727237986869464 15 4.0 70 0.9657251999688914

random_forest_optimized_roc
bootstrap max_depth max_features min_samples_leaf min_samples_split n_estimators
True 16 log2 3 18 491

>= 18 years old

In [24]:
present_chosen_hyperparameters(
    file_name=HEART_TRANSPLANT_EXPANDING_IDENTIFIER + '_365_me_18',
    filter_callback=optimized_filter,
    pretty=True,
)
ridge_optimized_roc
C
0.001

xgboost_optimized_roc
colsample_bytree gamma learning_rate max_depth min_child_weight n_estimators subsample
0.6598938205962753 3.9480532042454186 0.10799167560301227 2 5.0 120 0.5589408638569416

random_forest_optimized_roc
bootstrap max_depth max_features min_samples_leaf min_samples_split n_estimators
True 11 log2 19 9 302

90 days

In [25]:
present_chosen_hyperparameters(
    file_name=HEART_TRANSPLANT_EXPANDING_IDENTIFIER + '_90_all',
    filter_callback=optimized_filter,
    pretty=True,
)
ridge_optimized_roc
C
0.1

xgboost_optimized_roc
colsample_bytree gamma learning_rate max_depth min_child_weight n_estimators subsample
0.4820801782317513 4.359158362528796 0.012198380051845176 16 7.0 195 0.42673970439046754

random_forest_optimized_roc
bootstrap max_depth max_features min_samples_leaf min_samples_split n_estimators
True 16 auto 20 16 485